{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import jieba\n",
    "import torch\n",
    "import tqdm\n",
    "import nltk\n",
    "import pickle\n",
    "import os\n",
    "import regex as re\n",
    "import collections\n",
    "from opencc import OpenCC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_pretrained_bert import BertTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First load your data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Template for loading raw data (left commented out; adapt it to your own files).\n",
    "# It lists every file under data/, reads each one as CSV and concatenates them\n",
    "# into a single DataFrame `df` with a fresh index.\n",
    "#\n",
    "# filenames = os.listdir(\"data\")\n",
    "\n",
    "# df_list = []\n",
    "\n",
    "# for filename in filenames:\n",
    "#     filename = \"data/\" + filename\n",
    "#     df = pd.read_csv(filename)\n",
    "#     df_list.append(df)\n",
    "# df = pd.concat(df_list, axis=0, ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Some Data Cleaning Functions\n",
    "\n",
    "At a minimum, you need **filterPunctuation** for your dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getPercent(x):\n",
    "    return [float(item.split(\"=\")[1][:-1]) for item in x]\n",
    "\n",
    "# This one is the most imporant\n",
    "def filterPunctuation(x):\n",
    "    x = re.sub(r'[‘’]', \"'\", x)\n",
    "    x = re.sub(r'[“”]', '\"', x)\n",
    "    x = re.sub(r'[…]', '...', x)\n",
    "    x = re.sub(r'[—]', '-', x)\n",
    "    return x\n",
    "\n",
    "def filtering(x):\n",
    "    pattern = r\"[^\\u4e00-\\u9fff0-9A-Za-z\\s+\\.\\!\\/_,：:;-|$%^*()+\\\"\\'+——！，。？、《》“”~@#￥%…&*（）]+\"\n",
    "    return re.sub(pattern, \"\", x)\n",
    "\n",
    "def removeURL(x):\n",
    "    return re.sub(r'https?:\\/\\/[A-Za-z0-9.\\/\\-]*', '', x)\n",
    "\n",
    "# Module-level OpenCC converter built with the 't2s' configuration; reused by clean_data.\n",
    "cc = OpenCC('t2s')\n",
    "\n",
    "def clean_data(x):\n",
    "    \"\"\"Normalise one text field and return the cleaned string.\n",
    "\n",
    "    Steps, in order: strip surrounding whitespace, apply ``filterPunctuation``,\n",
    "    delete non-breaking spaces (U+00A0) and ordinary spaces, lower-case the text,\n",
    "    then pass it through ``cc.convert``.\n",
    "    \"\"\"\n",
    "    trimmed = x.strip()\n",
    "    normalized = filterPunctuation(trimmed)\n",
    "    for unwanted in (\"\\xa0\", \" \"):\n",
    "        normalized = normalized.replace(unwanted, \"\")\n",
    "    return cc.convert(normalized.lower())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset to clean; the next cell reads its \"summarization\",\n",
    "# \"origin_title\" and \"third_title\" columns.\n",
    "df_clean = pd.read_json(\"clean_data/clean_data.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean each field into a plain Python list (one entry per row of df_clean).\n",
    "# Only element 0 of every \"summarization\" entry is used as the article content.\n",
    "contents = [clean_data(summary[0]) for summary in df_clean[\"summarization\"]]\n",
    "origin_titles = list(map(clean_data, df_clean[\"origin_title\"]))\n",
    "modified_titles = list(map(clean_data, df_clean[\"third_title\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a row only if its cleaned content has at least 100 characters and the\n",
    "# modified title actually differs from the original title.\n",
    "all_indices = [\n",
    "    row\n",
    "    for row, (body, before, after) in enumerate(zip(contents, origin_titles, modified_titles))\n",
    "    if len(body) >= 100 and before != after\n",
    "]\n",
    "\n",
    "# Re-select the three parallel lists with the surviving row positions.\n",
    "contents = [contents[row] for row in all_indices]\n",
    "origin_titles = [origin_titles[row] for row in all_indices]\n",
    "modified_titles = [modified_titles[row] for row in all_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the three filtered lists into one DataFrame, one column per field.\n",
    "df_save = pd.DataFrame(dict(\n",
    "    contents=contents,\n",
    "    origin_titles=origin_titles,\n",
    "    modified_titles=modified_titles,\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "split your data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the roughly 80/20 train/dev split from a seeded generator so that re-running\n",
    "# the notebook reproduces exactly the same partition. Rows where `mask` is True go\n",
    "# to the training file, the others to the dev file (see the save cell below).\n",
    "SPLIT_SEED = 42\n",
    "split_rng = np.random.RandomState(SPLIT_SEED)\n",
    "mask = split_rng.rand(len(df_save)) < 0.8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save your data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows where `mask` is True are written as the training set, the rest as the dev set.\n",
    "df_save[mask].to_json(\"clean_data/train.json\", force_ascii=False)\n",
    "df_save[~mask].to_json(\"clean_data/dev.json\", force_ascii=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## If you have already run the code above, start from here!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the two splits written by the save step above, so the notebook can be resumed from here.\n",
    "df_train = pd.read_json(\"clean_data/train.json\")\n",
    "df_val = pd.read_json(\"clean_data/dev.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "define a tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer loaded from the pretrained \"bert-base-chinese\" name; the cells below use it\n",
    "# to split text into tokens and to map tokens to ids.\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-chinese\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Turn your data into tensors: input_ids, mask, and token types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "287958b8590748a09f0b7cd644b01a47",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=20476), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_data = []\n",
    "\n",
    "# Fixed sequence lengths (in tokens, special tokens included) for the two inputs.\n",
    "encoder_max = 220\n",
    "decoder_max = 40\n",
    "\n",
    "# Walk the three columns in lockstep instead of materialising a row with .iloc\n",
    "# three times per example.\n",
    "train_rows = zip(df_train['contents'], df_train['origin_titles'], df_train['modified_titles'])\n",
    "\n",
    "for content, origin, third in tqdm.tqdm_notebook(train_rows, total=len(df_train)):\n",
    "\n",
    "    # encoder input: [CLS] original title [SEP] content tokens, cut to encoder_max - 1, then [SEP]\n",
    "    content = tokenizer.tokenize(content)\n",
    "    origin = [\"[CLS]\"] + tokenizer.tokenize(origin) + [\"[SEP]\"]\n",
    "    encoder_input = (origin + content)[:encoder_max-1] + [\"[SEP]\"]\n",
    "\n",
    "    # Type id 0 marks the title part, 1 the rest. The title length is capped at the\n",
    "    # truncated input length so an over-long title can never produce more than\n",
    "    # encoder_max type ids (which would make the later torch.stack fail).\n",
    "    title_len = min(len(origin), len(encoder_input))\n",
    "    encoder_type_ids = [0] * title_len + [1] * (len(encoder_input) - title_len)\n",
    "\n",
    "    # decoder sequence: [CLS] modified title, cut to decoder_max - 1, then [SEP]\n",
    "    third = ([\"[CLS]\"] + tokenizer.tokenize(third))[:decoder_max-1] + [\"[SEP]\"]\n",
    "\n",
    "    # masks: 1 for every real token, 0 for the padding appended below\n",
    "    mask_encoder_input = [1] * len(encoder_input)\n",
    "    mask_third = [1] * len(third)\n",
    "\n",
    "    # convert tokens to ids\n",
    "    encoder_input = tokenizer.convert_tokens_to_ids(encoder_input)\n",
    "    third = tokenizer.convert_tokens_to_ids(third)\n",
    "\n",
    "    # right-pad everything with 0 up to the fixed lengths\n",
    "    encoder_pad = [0] * (encoder_max - len(encoder_input))\n",
    "    decoder_pad = [0] * (decoder_max - len(third))\n",
    "\n",
    "    # turn into tensors\n",
    "    encoder_input = torch.LongTensor(encoder_input + encoder_pad)\n",
    "    third = torch.LongTensor(third + decoder_pad)\n",
    "    mask_encoder_input = torch.LongTensor(mask_encoder_input + encoder_pad)\n",
    "    mask_third = torch.LongTensor(mask_third + decoder_pad)\n",
    "    encoder_type_ids = torch.LongTensor(encoder_type_ids + encoder_pad)\n",
    "    # the modified-title side gets all-zero type ids of length decoder_max\n",
    "    third_type_ids = torch.zeros(decoder_max).long()\n",
    "\n",
    "    train_data.append((encoder_input,\n",
    "                       third,\n",
    "                       mask_encoder_input,\n",
    "                       mask_third,\n",
    "                       encoder_type_ids,\n",
    "                       third_type_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transpose the list of per-example tuples into six per-field tuples of tensors.\n",
    "(encoder_input,\n",
    " third,\n",
    " mask_encoder_input,\n",
    " mask_third,\n",
    " encoder_type_ids,\n",
    " third_type_ids) = zip(*train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack each field's per-example tensors into a single tensor per field.\n",
    "(encoder_input,\n",
    " third,\n",
    " mask_encoder_input,\n",
    " mask_third,\n",
    " encoder_type_ids,\n",
    " third_type_ids) = [\n",
    "    torch.stack(field)\n",
    "    for field in (encoder_input, third, mask_encoder_input,\n",
    "                  mask_third, encoder_type_ids, third_type_ids)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the six stacked tensors in a fixed order. Note that train_data is rebound\n",
    "# here from the per-example list of tuples to this per-field list.\n",
    "train_data = [\n",
    "    encoder_input,\n",
    "    third,\n",
    "    mask_encoder_input,\n",
    "    mask_third,\n",
    "    encoder_type_ids,\n",
    "    third_type_ids,\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the processed data with `torch.save`. The file is pickle-based, but load it back with `torch.load` rather than plain `pickle.load`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the list of six stacked training tensors built in the previous cell.\n",
    "torch.save(train_data, \"train_data.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d44d20976fde4e03961243ec71278217",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5161), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "val_data = []\n",
    "\n",
    "# Fixed sequence lengths (in tokens, special tokens included) for the two inputs.\n",
    "encoder_max = 220\n",
    "decoder_max = 40\n",
    "\n",
    "# Walk the three columns in lockstep instead of materialising a row with .iloc\n",
    "# three times per example.\n",
    "val_rows = zip(df_val['contents'], df_val['origin_titles'], df_val['modified_titles'])\n",
    "\n",
    "for content, origin, third in tqdm.tqdm_notebook(val_rows, total=len(df_val)):\n",
    "\n",
    "    # encoder input: [CLS] original title [SEP] content tokens, cut to encoder_max - 1, then [SEP]\n",
    "    content = tokenizer.tokenize(content)\n",
    "    origin = [\"[CLS]\"] + tokenizer.tokenize(origin) + [\"[SEP]\"]\n",
    "    encoder_input = (origin + content)[:encoder_max-1] + [\"[SEP]\"]\n",
    "\n",
    "    # Type id 0 marks the title part, 1 the rest. The title length is capped at the\n",
    "    # truncated input length so an over-long title can never produce more than\n",
    "    # encoder_max type ids (which would make the later torch.stack fail).\n",
    "    title_len = min(len(origin), len(encoder_input))\n",
    "    encoder_type_ids = [0] * title_len + [1] * (len(encoder_input) - title_len)\n",
    "\n",
    "    # decoder sequence: [CLS] modified title, cut to decoder_max - 1, then [SEP]\n",
    "    third = ([\"[CLS]\"] + tokenizer.tokenize(third))[:decoder_max-1] + [\"[SEP]\"]\n",
    "\n",
    "    # masks: 1 for every real token, 0 for the padding appended below\n",
    "    mask_encoder_input = [1] * len(encoder_input)\n",
    "    mask_third = [1] * len(third)\n",
    "\n",
    "    # convert tokens to ids\n",
    "    encoder_input = tokenizer.convert_tokens_to_ids(encoder_input)\n",
    "    third = tokenizer.convert_tokens_to_ids(third)\n",
    "\n",
    "    # right-pad everything with 0 up to the fixed lengths\n",
    "    encoder_pad = [0] * (encoder_max - len(encoder_input))\n",
    "    decoder_pad = [0] * (decoder_max - len(third))\n",
    "\n",
    "    # turn into tensors\n",
    "    encoder_input = torch.LongTensor(encoder_input + encoder_pad)\n",
    "    third = torch.LongTensor(third + decoder_pad)\n",
    "    mask_encoder_input = torch.LongTensor(mask_encoder_input + encoder_pad)\n",
    "    mask_third = torch.LongTensor(mask_third + decoder_pad)\n",
    "    encoder_type_ids = torch.LongTensor(encoder_type_ids + encoder_pad)\n",
    "    # the modified-title side gets all-zero type ids of length decoder_max\n",
    "    third_type_ids = torch.zeros(decoder_max).long()\n",
    "\n",
    "    val_data.append((encoder_input,\n",
    "                     third,\n",
    "                     mask_encoder_input,\n",
    "                     mask_third,\n",
    "                     encoder_type_ids,\n",
    "                     third_type_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transpose the list of per-example tuples into six per-field tuples of tensors.\n",
    "(encoder_input,\n",
    " third,\n",
    " mask_encoder_input,\n",
    " mask_third,\n",
    " encoder_type_ids,\n",
    " third_type_ids) = zip(*val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack each field's per-example tensors into a single tensor per field.\n",
    "(encoder_input,\n",
    " third,\n",
    " mask_encoder_input,\n",
    " mask_third,\n",
    " encoder_type_ids,\n",
    " third_type_ids) = [\n",
    "    torch.stack(field)\n",
    "    for field in (encoder_input, third, mask_encoder_input,\n",
    "                  mask_third, encoder_type_ids, third_type_ids)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the six stacked tensors in a fixed order. Note that val_data is rebound\n",
    "# here from the per-example list of tuples to this per-field list.\n",
    "val_data = [\n",
    "    encoder_input,\n",
    "    third,\n",
    "    mask_encoder_input,\n",
    "    mask_third,\n",
    "    encoder_type_ids,\n",
    "    third_type_ids,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the list of six stacked validation tensors built in the previous cell.\n",
    "torch.save(val_data, \"val_data.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
